{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data checks\n",
    "\n",
    "The best approach depends on your domain (don't assume that what everyone is doing to prep datasets for LLM pre-training is relevant for you).\n",
    "\n",
    "The examples below don't aim to be as efficient as possible. They show, at a high level, the different approaches you can take and what each one produces."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install datasets -qq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rule-based checks\n",
    "\n",
    "You can use rules and simple functions both to filter and deduplicate data and to check the quality of synthetic data generations.\n",
    "\n",
    "As you'd expect, the rules you define will depend on your domain and your data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Checking the quality of synthetic data\n",
    "\n",
    "The dataset below is a synthetic dataset in which an LLM generates tl;dr summaries of dataset cards. Some of the instructions in our prompt tell the model not to repeat information that is better kept as structured metadata. In the synthetic dataset generation pipeline, a judge LLM checks the quality of the generated summaries against these instructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\n",
    "    \"davanstrien/dataset-preferences-llm-course-full-dataset\", split=\"train\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['datasetId', 'card', 'instruction', 'system_prompt', 'generation_models', 'generations', 'model_name', 'ratings', 'rationales'],\n",
       "    num_rows: 2482\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"\\n<instructions>\\nWrite a tl;dr summary of a dataset based on the dataset card. Focus on the most critical aspects of the dataset.\\n\\nThe summary should aim to concisely describe the dataset.\\n</instructions>\\n\\n<card>\\n\\nIn this dataset, you will find a collection of records that show a category, an instruction, a context and a response to that instruction. The aim of the project is to correct the instructions, intput and responses to make sure they are of the highest quality and that they match the task category that they belong to. All three texts should be clear and include real information. In addition, the response should be as complete but concise as possible.\\nTo curate the dataset, you will need to provide an answer to the following text fields:\\n1 - Final instruction:\\nThe final version of the instruction field. You may copy it using the copy icon in the instruction field. Leave it as it is if it's ok or apply any necessary corrections. Remember to change the instruction if it doesn't represent well the task category of the record.\\n2 - Final context:\\nThe final version of the instruction field. You may copy it using the copy icon in the context field. Leave it as it is if it's ok or apply any necessary corrections. If the task category and instruction don't need of an context to be completed, leave this question blank.\\n3 - Final response:\\nThe final version of the response field. You may copy it using the copy icon in the response field. Leave it as it is if it's ok or apply any necessary corrections. Check that the response makes sense given all the fields above.\\nYou will need to provide at least an instruction and a response for all records. 
If you are not sure about a record and you prefer not to provide a response, click Discard.\\n## Fields\\n* `id` is of type <class 'str'>\\n* `category` is of type <class 'str'>\\n* `original-instruction` is of type <class 'str'>\\n* `original-context` is of type <class 'str'>\\n* `original-response` is of type <class 'str'>\\n## Questions\\n* `new-instruction` : Write the final version of the instruction, making sure that it matches the task category. If the original instruction is ok, copy and paste it here.\\n* `new-context` : Write the final version of the context, making sure that it makes sense with the task category. If the original context is ok, copy and paste it here. If an context is not needed, leave this empty.\\n* `new-response` : Write the final version of the response, making sure that it matches the task category and makes sense for the instruction (and context) provided. If the original response is ok, copy and paste it here.\\n## Load with Argilla\\nTo load this dataset with Argilla, you'll just need to install Argilla as `pip install argilla --upgrade` and then use the following code:\\n```python\\nimport argilla as rg\\nds = rg.FeedbackDataset.from_huggingface('argilla/databricks-dolly-15k-curated-en')\\n```\\n## Load with Datasets\\nTo load this dataset with Datasets, you'll just need to install Datasets as `pip install datasets --upgrade` and then use the following code:\\n```python\\nfrom datasets import load_dataset\\nds = load_dataset('argilla/databricks-dolly-15k-curated-en')\\n```\\n\\n</card>\\n\\n<instructions>\\nIf the card provides the necessary information, say what the dataset can be used for.\\nYou do not need to mention that the dataset is hosted or available on the Hugging Face Hub.\\nDo not mention the license of the dataset.\\nDo not mention the number of examples in the training or test split.\\nOnly mention size if there is extensive discussion of the scale of the dataset in the dataset card.\\nDo not speculate on anything 
not explicitly mentioned in the dataset card.\\nIn general avoid references to the quality of the dataset i.e. don't use phrases like 'a high-quality dataset' in the summary.\\n</instructions>\\n\\n<One sentence summary>\""
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]['instruction']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One of the instructions we give the LLM is not to make quality judgements about the dataset in the card, e.g. not to say things like \"this is a high quality dataset\". Although the judge LLM will identify some of these, we can also run a much cheaper and simpler check using a rule-based approach."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_is_high_quality(row):\n",
    "    generations = row['generations']\n",
    "    # Skip rows where a generation is missing\n",
    "    if any(generation is None for generation in generations):\n",
    "        return False\n",
    "    # Test each phrase separately: `\"a\" or \"b\" in text` is always truthy\n",
    "    phrases = (\"high quality\", \"high-quality\")\n",
    "    return any(p in generation.lower() for generation in generations for p in phrases)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'datasetId': 'argilla/databricks-dolly-15k-curated-en',\n",
       " 'card': \"In this dataset, you will find a collection of records that show a category, an instruction, a context and a response to that instruction. The aim of the project is to correct the instructions, intput and responses to make sure they are of the highest quality and that they match the task category that they belong to. All three texts should be clear and include real information. In addition, the response should be as complete but concise as possible.\\nTo curate the dataset, you will need to provide an answer to the following text fields:\\n1 - Final instruction:\\nThe final version of the instruction field. You may copy it using the copy icon in the instruction field. Leave it as it is if it's ok or apply any necessary corrections. Remember to change the instruction if it doesn't represent well the task category of the record.\\n2 - Final context:\\nThe final version of the instruction field. You may copy it using the copy icon in the context field. Leave it as it is if it's ok or apply any necessary corrections. If the task category and instruction don't need of an context to be completed, leave this question blank.\\n3 - Final response:\\nThe final version of the response field. You may copy it using the copy icon in the response field. Leave it as it is if it's ok or apply any necessary corrections. Check that the response makes sense given all the fields above.\\nYou will need to provide at least an instruction and a response for all records. If you are not sure about a record and you prefer not to provide a response, click Discard.\\n## Fields\\n* `id` is of type <class 'str'>\\n* `category` is of type <class 'str'>\\n* `original-instruction` is of type <class 'str'>\\n* `original-context` is of type <class 'str'>\\n* `original-response` is of type <class 'str'>\\n## Questions\\n* `new-instruction` : Write the final version of the instruction, making sure that it matches the task category. 
If the original instruction is ok, copy and paste it here.\\n* `new-context` : Write the final version of the context, making sure that it makes sense with the task category. If the original context is ok, copy and paste it here. If an context is not needed, leave this empty.\\n* `new-response` : Write the final version of the response, making sure that it matches the task category and makes sense for the instruction (and context) provided. If the original response is ok, copy and paste it here.\\n## Load with Argilla\\nTo load this dataset with Argilla, you'll just need to install Argilla as `pip install argilla --upgrade` and then use the following code:\\n```python\\nimport argilla as rg\\nds = rg.FeedbackDataset.from_huggingface('argilla/databricks-dolly-15k-curated-en')\\n```\\n## Load with Datasets\\nTo load this dataset with Datasets, you'll just need to install Datasets as `pip install datasets --upgrade` and then use the following code:\\n```python\\nfrom datasets import load_dataset\\nds = load_dataset('argilla/databricks-dolly-15k-curated-en')\\n```\",\n",
       " 'instruction': \"\\n<instructions>\\nWrite a tl;dr summary of a dataset based on the dataset card. Focus on the most critical aspects of the dataset.\\n\\nThe summary should aim to concisely describe the dataset.\\n</instructions>\\n\\n<card>\\n\\nIn this dataset, you will find a collection of records that show a category, an instruction, a context and a response to that instruction. The aim of the project is to correct the instructions, intput and responses to make sure they are of the highest quality and that they match the task category that they belong to. All three texts should be clear and include real information. In addition, the response should be as complete but concise as possible.\\nTo curate the dataset, you will need to provide an answer to the following text fields:\\n1 - Final instruction:\\nThe final version of the instruction field. You may copy it using the copy icon in the instruction field. Leave it as it is if it's ok or apply any necessary corrections. Remember to change the instruction if it doesn't represent well the task category of the record.\\n2 - Final context:\\nThe final version of the instruction field. You may copy it using the copy icon in the context field. Leave it as it is if it's ok or apply any necessary corrections. If the task category and instruction don't need of an context to be completed, leave this question blank.\\n3 - Final response:\\nThe final version of the response field. You may copy it using the copy icon in the response field. Leave it as it is if it's ok or apply any necessary corrections. Check that the response makes sense given all the fields above.\\nYou will need to provide at least an instruction and a response for all records. 
If you are not sure about a record and you prefer not to provide a response, click Discard.\\n## Fields\\n* `id` is of type <class 'str'>\\n* `category` is of type <class 'str'>\\n* `original-instruction` is of type <class 'str'>\\n* `original-context` is of type <class 'str'>\\n* `original-response` is of type <class 'str'>\\n## Questions\\n* `new-instruction` : Write the final version of the instruction, making sure that it matches the task category. If the original instruction is ok, copy and paste it here.\\n* `new-context` : Write the final version of the context, making sure that it makes sense with the task category. If the original context is ok, copy and paste it here. If an context is not needed, leave this empty.\\n* `new-response` : Write the final version of the response, making sure that it matches the task category and makes sense for the instruction (and context) provided. If the original response is ok, copy and paste it here.\\n## Load with Argilla\\nTo load this dataset with Argilla, you'll just need to install Argilla as `pip install argilla --upgrade` and then use the following code:\\n```python\\nimport argilla as rg\\nds = rg.FeedbackDataset.from_huggingface('argilla/databricks-dolly-15k-curated-en')\\n```\\n## Load with Datasets\\nTo load this dataset with Datasets, you'll just need to install Datasets as `pip install datasets --upgrade` and then use the following code:\\n```python\\nfrom datasets import load_dataset\\nds = load_dataset('argilla/databricks-dolly-15k-curated-en')\\n```\\n\\n</card>\\n\\n<instructions>\\nIf the card provides the necessary information, say what the dataset can be used for.\\nYou do not need to mention that the dataset is hosted or available on the Hugging Face Hub.\\nDo not mention the license of the dataset.\\nDo not mention the number of examples in the training or test split.\\nOnly mention size if there is extensive discussion of the scale of the dataset in the dataset card.\\nDo not speculate on anything 
not explicitly mentioned in the dataset card.\\nIn general avoid references to the quality of the dataset i.e. don't use phrases like 'a high-quality dataset' in the summary.\\n</instructions>\\n\\n<One sentence summary>\",\n",
       " 'system_prompt': 'You are a helpful, respectful and honest assistant`. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\\nYour role is to write short tl;dr descriptions of datasets based on existing dataset cards',\n",
       " 'generation_models': ['meta-llama/Meta-Llama-3-70B-Instruct',\n",
       "  'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO'],\n",
       " 'generations': ['This dataset contains categorized records of instructions, contexts, and responses, curated to ensure clarity, accuracy, and relevance, suitable for training AI models to understand and respond to instructions effectively.',\n",
       "  'The Databricks-Dolly-15k-Curated dataset comprises corrected records of instructions, context, and responses that adhere to specific categories, aiming to ensure high-quality data and real information, which can be utilized to enhance natural language understanding and generation tasks.'],\n",
       " 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct',\n",
       " 'ratings': [5, 3],\n",
       " 'rationales': [\"The text concisely summarizes the dataset, focusing on its critical aspects, and meets all the restrictions. It accurately describes the dataset's content and potential use, without speculating or mentioning unnecessary information.\",\n",
       "  'The text partially complies with the instruction. It summarizes the dataset\\'s content and goal, but mentions \"high-quality data,\" which is not explicitly mentioned in the dataset card and goes against the restriction of avoiding references to the quality of the dataset. Additionally, it\\'s a bit wordy and could be more concise.']}"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset.filter(check_is_high_quality)[0]"
   ]
  },
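  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same idea extends to a short list of phrases. Below is a minimal sketch, not the notebook's pipeline: `QUALITY_PATTERN` and `find_quality_claims` are hypothetical names, and the phrase list is an assumption you would tune for your own prompts.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical phrase list (an assumption, not taken from the dataset)\n",
    "QUALITY_PATTERN = re.compile(\n",
    "    r'\\b(?:high|top|best)[- ]quality\\b|\\bstate[- ]of[- ]the[- ]art\\b',\n",
    "    re.IGNORECASE,\n",
    ")\n",
    "\n",
    "def find_quality_claims(generations):\n",
    "    '''Return the matched phrases for each generation; None yields an empty list.'''\n",
    "    return [\n",
    "        QUALITY_PATTERN.findall(generation) if generation else []\n",
    "        for generation in generations\n",
    "    ]\n",
    "\n",
    "example = [\n",
    "    'This dataset contains categorized instruction records.',\n",
    "    'A corpus aiming to ensure high-quality data.',\n",
    "    None,\n",
    "]\n",
    "print(find_quality_claims(example))  # [[], ['high-quality'], []]\n",
    "```\n",
    "\n",
    "Returning the matched phrases rather than a boolean makes it easier to spot-check which wording triggered the rule."
   ]
  },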
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also don't want the generated summaries to mention the license of the dataset, as this is already in the structured metadata. As a first look, we can scan the judge's rationales for the first row for mentions of a license."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each rationale is a string, so iterate over the list directly\n",
    "for rationale in dataset[0]['rationales']:\n",
    "    if \"license\" in rationale.lower():\n",
    "        print(rationale)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'This software is licensed under the MIT License.': True, 'Released under the GNU General Public License, version 3.': True, 'This code follows the Apache License 2.0 guidelines.': True, 'All rights reserved under a Proprietary License.': True, 'The project is licensed under Creative Commons CC-BY-SA 4.0.': True, 'The document was last updated on 2023-05-14.': True, 'This agreement was signed on January 1, 2024.': True, 'No relevant license or date information here.': False}\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "def check_license_text(text):\n",
    "    # List of regex patterns for common software licenses and date formats\n",
    "    license_patterns = [\n",
    "        r'\\bMIT License\\b',\n",
    "        r'\\bGNU General Public License\\b',\n",
    "        r'\\bGPL\\b',\n",
    "        r'\\bApache License\\b',\n",
    "        r'\\bBSD License\\b',\n",
    "        r'\\bMozilla Public License\\b',\n",
    "        r'\\bMPL\\b',\n",
    "        r'\\bCreative Commons\\b',\n",
    "        r'\\bCC-BY\\b',\n",
    "        r'\\bCC-BY-SA\\b',\n",
    "        r'\\bProprietary License\\b',\n",
    "        # Common date formats\n",
    "        r'\\b\\d{4}-\\d{2}-\\d{2}\\b',  # YYYY-MM-DD\n",
    "        r'\\b\\d{4}/\\d{2}/\\d{2}\\b',  # YYYY/MM/DD\n",
    "        r'\\b\\d{2}/\\d{2}/\\d{4}\\b',  # MM/DD/YYYY or DD/MM/YYYY\n",
    "        r'\\b(?:January|February|March|April|May|June|July|August|September|October|November|December) \\d{1,2}, \\d{4}\\b'\n",
    "    ]\n",
    "\n",
    "    # Compile patterns into a single regex\n",
    "    combined_pattern = re.compile('|'.join(license_patterns), re.IGNORECASE)\n",
    "\n",
    "    # True if any of the patterns matches a substring within the text\n",
    "    return combined_pattern.search(text) is not None\n",
    "\n",
    "# Example usage\n",
    "texts = [\n",
    "    \"This software is licensed under the MIT License.\",\n",
    "    \"Released under the GNU General Public License, version 3.\",\n",
    "    \"This code follows the Apache License 2.0 guidelines.\",\n",
    "    \"All rights reserved under a Proprietary License.\",\n",
    "    \"The project is licensed under Creative Commons CC-BY-SA 4.0.\",\n",
    "    \"The document was last updated on 2023-05-14.\",\n",
    "    \"This agreement was signed on January 1, 2024.\",\n",
    "    \"No relevant license or date information here.\"\n",
    "]\n",
    "\n",
    "# Testing the function\n",
    "results = {text: check_license_text(text) for text in texts}\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also filter on things like mentions of train and test splits, as this is not relevant to the summary of the dataset card and is something we prompt the model to not include."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train': True, 'test': True, 'validation': True, 'numbers': []}\n"
     ]
    }
   ],
   "source": [
    "import re\n",
    "\n",
    "def check_train_test_val_splits(text):\n",
    "    \"\"\"\n",
    "    Function to check for mentions of train, test, and validation splits in the provided text.\n",
    "\n",
    "    Parameters:\n",
    "    text (str): The input text to check.\n",
    "\n",
    "    Returns:\n",
    "    dict: A dictionary indicating whether train, test, and validation splits are mentioned.\n",
    "    \"\"\"\n",
    "\n",
    "    # Define regex patterns for train, test, and validation splits\n",
    "    patterns = {\n",
    "        \"train\": re.compile(r'\\btrain(?:ing)?(?:\\s*set)?(?:\\s*split)?\\b', re.IGNORECASE),\n",
    "        \"test\": re.compile(r'\\btest(?:ing)?(?:\\s*set)?(?:\\s*split)?\\b', re.IGNORECASE),\n",
    "        \"validation\": re.compile(r'\\b(?:validation|val|dev)(?:\\s*set)?(?:\\s*split)?\\b', re.IGNORECASE),\n",
    "        \"numbers\": re.compile(r'\\b(?:\\d+|\\d*\\.\\d+)(?:\\s*k|K|%|\\s*percent)?\\b', re.IGNORECASE)\n",
    "    }\n",
    "\n",
    "    # Initialize results dictionary\n",
    "    results = {\n",
    "        \"train\": False,\n",
    "        \"test\": False,\n",
    "        \"validation\": False,\n",
    "        \"numbers\": []\n",
    "    }\n",
    "\n",
    "    # Check for matches in the text\n",
    "    for key, pattern in patterns.items():\n",
    "        if matches := pattern.findall(text):\n",
    "            if key == \"numbers\":\n",
    "                results[key].extend(matches)\n",
    "            else:\n",
    "                results[key] = True\n",
    "\n",
    "    return results\n",
    "\n",
    "# Example usage\n",
    "text = \"\"\"\n",
    "The dataset is split into a training set, a test set, and a validation set.\n",
    "\"\"\"\n",
    "\n",
    "result = check_train_test_val_splits(text)\n",
    "print(result)"
   ]
  },
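  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A bare match on `train` also fires on phrases like \"suitable for training AI models\", which appears in the first row's generations. One stricter option, sketched below under the assumption that we only care about split *sizes*, is to flag a split name only when a number comes shortly before it. `SPLIT_COUNT` and `mentions_split_size` are hypothetical names, and the two-word window is an arbitrary choice.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical heuristic: a number, then at most two words, then a split name\n",
    "SPLIT_COUNT = re.compile(\n",
    "    r'\\b\\d[\\d,]*(?:\\.\\d+)?\\s*[km%]?\\s+(?:\\w+\\s+){0,2}?'\n",
    "    r'(?:train(?:ing)?|test(?:ing)?|validation|dev)\\b',\n",
    "    re.IGNORECASE,\n",
    ")\n",
    "\n",
    "def mentions_split_size(text):\n",
    "    '''Return True if a number appears shortly before a split name.'''\n",
    "    return SPLIT_COUNT.search(text) is not None\n",
    "\n",
    "print(mentions_split_size('Suitable for training AI models.'))                 # False\n",
    "print(mentions_split_size('Contains 10,000 training and 500 test examples.'))  # True\n",
    "```\n",
    "\n",
    "This is still a heuristic: a sentence such as \"version 2 of the test suite\" would be flagged, so it is worth reviewing a sample of matches before filtering on it."
   ]
  },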
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SQL rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_dataset = load_dataset(\"gretelai/synthetic_text_to_sql\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],\n",
       "    num_rows: 100000\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['What is the total volume of timber sold by each salesperson, sorted by salesperson?',\n",
       " 'List all the unique equipment types and their corresponding total maintenance frequency from the equipment_maintenance table.',\n",
       " 'How many marine species are found in the Southern Ocean?',\n",
       " 'What is the total trade value and average price for each trader and stock in the trade_history table?']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset[:4]['sql_prompt']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Preventing certain kinds of input\n",
    "\n",
    "We may not want to allow users to add SQL to their prompts. This is something we can enforce at the application layer, but we might also want to exclude this kind of prompt from our training data. A model that just copies the SQL from the prompt may appear to be doing better than it is, and that SQL won't be available in production."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['SELECT MemberID, Name, Age, Gender, AVG(WorkoutDuration) as AverageWorkoutDuration FROM Members JOIN Workouts ON Members.MemberID = Workouts.MemberID GROUP BY MemberID, Name, Age, Gender ORDER BY AverageWorkoutDuration DESC;',\n",
       " 'SELECT MemberID, AVG(Steps) as AverageSteps, AVG(Calories) as AverageCalories, AVG(HeartRate) as AverageHeartRate FROM Wearables GROUP BY MemberID ORDER BY AverageSteps DESC;',\n",
       " \"SELECT MemberID, COUNT(*) as WorkoutCountThisWeek FROM Workouts WHERE Date >= DATE_TRUNC('week', CURRENT_DATE) GROUP BY MemberID ORDER BY WorkoutCountThisWeek DESC;\",\n",
       " \"SELECT DISTINCT MemberID, Gender FROM Members WHERE State = 'CA';\",\n",
       " \"SELECT MemberID, WorkoutType, DATE_TRUNC('day', Date) as Day, MIN(Duration) as MinDurationPerDay FROM Workouts GROUP BY MemberID, WorkoutType, Day ORDER BY Day DESC;\"]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset.filter(lambda x: \"SELECT\" in x['sql_prompt'])[:]['sql_prompt'][:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['SELECT MemberID, Name, Age, Gender, AVG(WorkoutDuration) as AverageWorkoutDuration FROM Members JOIN Workouts ON Members.MemberID = Workouts.MemberID GROUP BY MemberID, Name, Age, Gender ORDER BY AverageWorkoutDuration DESC;',\n",
       " 'SELECT MemberID, AVG(Steps) as AverageSteps, AVG(Calories) as AverageCalories, AVG(HeartRate) as AverageHeartRate FROM Wearables GROUP BY MemberID ORDER BY AverageSteps DESC;',\n",
       " \"SELECT MemberID, COUNT(*) as WorkoutCountThisWeek FROM Workouts WHERE Date >= DATE_TRUNC('week', CURRENT_DATE) GROUP BY MemberID ORDER BY WorkoutCountThisWeek DESC;\",\n",
       " \"SELECT DISTINCT MemberID, Gender FROM Members WHERE State = 'CA';\"]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset.filter(lambda x: \"FROM\" in x['sql_prompt'])[:4]['sql_prompt']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Preventing certain kinds of output\n",
    "\n",
    "Similarly we may want to prevent the model from violating certain rules in the SQL it generates. We can do a more \"formal\" check or compilation of the SQL to check for this but we can also do a simple check for certain keywords or phrases that we don't want to appear in the generated SQL.\n",
    "\n",
    "![](https://imgs.xkcd.com/comics/exploits_of_a_mom.png)"
   ]
  },
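  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the more \"formal\" route, the sketch below asks SQLite (via Python's standard `sqlite3` module) to compile a query against a schema. It assumes the schema is available as `CREATE TABLE` statements, as in the `sql_context` column; `compiles_in_sqlite` is a hypothetical helper, and the example table is made up.\n",
    "\n",
    "```python\n",
    "import sqlite3\n",
    "\n",
    "def compiles_in_sqlite(sql_context, sql):\n",
    "    '''Return True if `sql` compiles against the schema built from `sql_context`.'''\n",
    "    connection = sqlite3.connect(':memory:')\n",
    "    try:\n",
    "        # Build the schema (and any sample rows) from the context statements\n",
    "        connection.executescript(sql_context)\n",
    "        # EXPLAIN makes SQLite prepare the statement without running it\n",
    "        connection.execute('EXPLAIN ' + sql)\n",
    "        return True\n",
    "    except (sqlite3.Error, sqlite3.Warning):\n",
    "        return False\n",
    "    finally:\n",
    "        connection.close()\n",
    "\n",
    "# Made-up schema for illustration\n",
    "context = 'CREATE TABLE members (member_id INTEGER, state TEXT);'\n",
    "print(compiles_in_sqlite(context, 'SELECT member_id FROM members WHERE member_id > 10;'))  # True\n",
    "print(compiles_in_sqlite(context, 'SELECT member_id FROM memberz;'))                       # False\n",
    "```\n",
    "\n",
    "A `False` here only means SQLite rejected the statement. Queries written for another dialect can fail for that reason alone, so treat the result as a signal to review rather than proof that the SQL is wrong."
   ]
  },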
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "048ac629eec74be5b9de9ba320f46c43",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "['DROP TABLE vehicle_safety_testing;',\n",
       " 'DROP TABLE redundant_billing_data;',\n",
       " 'DROP TABLE conditions;',\n",
       " 'DROP VIEW young_engineers;']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset.filter(lambda x: \"DROP\" in x['sql'])[:4]['sql']"
   ]
  },
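  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A plain substring test is case-sensitive and also matches inside longer identifiers such as `DROPOUT_RATE`. A slightly more careful variant, sketched below, uses word boundaries and ignores case. The keyword list and the names `DESTRUCTIVE` and `has_destructive_statement` are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical deny-list of statement keywords; adjust to your application\n",
    "DESTRUCTIVE = re.compile(r'\\b(?:DROP|TRUNCATE|DELETE|ALTER)\\b', re.IGNORECASE)\n",
    "\n",
    "def has_destructive_statement(sql):\n",
    "    '''Return True if the SQL contains one of the deny-listed keywords as a whole word.'''\n",
    "    return DESTRUCTIVE.search(sql) is not None\n",
    "\n",
    "print(has_destructive_statement('drop table conditions;'))             # True\n",
    "print(has_destructive_statement('SELECT DROPOUT_RATE FROM schools;'))  # False\n",
    "```\n",
    "\n",
    "Keywords that appear inside string literals are still flagged, so this remains a cheap first pass rather than a parser."
   ]
  },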
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can ask an LLM to help us identify things we might want to check for. In the past we might have been too lazy to write many functions that each capture only a narrow rule, but with an LLM we can do this much more easily."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](claude_sql.png)"
   ]
  },
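  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We might want to filter out queries that contain double-quoted text. Standard SQL reserves double quotes for identifiers and uses single quotes for string literals, so double-quoted string values are not portable and might cause issues in production."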
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def has_double_quoted_strings(sql):\n",
    "    pattern = r'\"[^\"]*\"'\n",
    "    matches = re.findall(pattern, sql)\n",
    "    return len(matches) > 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "acc751129c9e4fb786c9710ddc9021f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "['SELECT program_id FROM professional_development JOIN teachers t ON professional_development.attended_teacher_id = t.teacher_id GROUP BY program_id HAVING COUNT(DISTINCT t.teacher_id) > 1 AND t.state = \"NY\";',\n",
       " 'SELECT COUNT(*) FROM students s1 WHERE s1.mental_health_score > (SELECT AVG(s2.mental_health_score) FROM students s2 WHERE s2.state = \"CA\");',\n",
       " 'SELECT YEAR(completion_date) AS \"Completion Year\", COUNT(*) FROM economic_diversification WHERE project_status = \\'completed\\' GROUP BY YEAR(completion_date);',\n",
       " 'SELECT MIN(listing_price) AS \"Lowest Price\", MAX(listing_price) AS \"Highest Price\" FROM sustainable_urban WHERE city = \\'New York\\';',\n",
       " 'SELECT MIN(productivity) AS \"Minimum productivity\", MAX(productivity) AS \"Maximum productivity\" FROM productivity WHERE department = \\'environment\\' AND year = 2021;']"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset.filter(lambda x: has_double_quoted_strings(x[\"sql\"]))['sql'][:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, a `HAVING` clause without a `GROUP BY` could be an interesting filter to apply."
   ]
  },
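  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def has_having_without_group_by(sql):\n",
    "    # GROUP BY precedes HAVING, so a lookahead placed after HAVING never sees it;\n",
    "    # check for the two clauses separately instead\n",
    "    has_having = re.search(r'\\bHAVING\\b', sql, re.IGNORECASE) is not None\n",
    "    has_group_by = re.search(r'\\bGROUP\\s+BY\\b', sql, re.IGNORECASE) is not None\n",
    "    return has_having and not has_group_by"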
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3d5adc55fda04519a9ab454efc20016c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "['SELECT policy FROM PolicyAnalysis GROUP BY policy HAVING COUNT(DISTINCT department) = (SELECT COUNT(DISTINCT department) FROM PolicyAnalysis);',\n",
       " \"SELECT COUNT(*) FROM factories WHERE sector = 'renewable energy' HAVING employee_count > 50;\",\n",
       " \"SELECT name FROM menu_items WHERE restaurant_id = 3 AND available_time < '17:00:00' GROUP BY name HAVING SUM(available) = 0;\",\n",
       " \"SELECT mission_name FROM SpaceMissions WHERE astronaut_nationality IN ('US', 'Canada') GROUP BY mission_name HAVING COUNT(DISTINCT astronaut_nationality) = 2;\",\n",
       " 'SELECT AVG(grant_avg) FROM (SELECT department, AVG(amount) AS grant_avg FROM research_grants GROUP BY department HAVING COUNT(*) >= 3) AS subquery;']"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset.filter(lambda x: has_having_without_group_by(x[\"sql\"]))['sql'][:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lastly, we would like to avoid basic SUM queries without NULL handling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def has_sum_without_null_handling(sql):\n",
    "    pattern = r'SUM\\s*\\(\\s*\\w+\\s*\\)'\n",
    "    match = re.search(pattern, sql, re.IGNORECASE)\n",
    "    return match is not None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id GROUP BY salesperson_id, name ORDER BY total_volume DESC;',\n",
       " 'SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;',\n",
       " 'SELECT SUM(spending) FROM defense.eu_humanitarian_assistance WHERE year BETWEEN 2019 AND 2021;',\n",
       " \"SELECT SUM(fare) FROM bus_routes WHERE route_name = 'Green Line';\",\n",
       " \"SELECT SUM(Amount) AS TotalAssistance, Country FROM HumanitarianAssistance WHERE Organization NOT IN ('Government', 'Military') AND Year BETWEEN 2016 AND 2020 GROUP BY Country;\"]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_dataset.filter(lambda x: has_sum_without_null_handling(x[\"sql\"]))['sql'][:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can combine all these rules to flag any data that doesn't meet our criteria as domain experts. The cell below counts the rows that trip at least one rule."
   ]
  },
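  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When the number of rules grows, it can help to keep them in a named registry so you can see *which* rule flagged a row. The sketch below is one possible layout, not the only one: `RULES` and `flagged_rules` are hypothetical names, and the two predicates reuse the double-quote and `SUM` patterns from the earlier cells.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Hypothetical registry: rule name -> predicate that returns True when a query is flagged\n",
    "RULES = {\n",
    "    'double_quotes': lambda sql: re.search(r'\"[^\"]*\"', sql) is not None,\n",
    "    'bare_sum': lambda sql: re.search(r'SUM\\s*\\(\\s*\\w+\\s*\\)', sql, re.IGNORECASE) is not None,\n",
    "}\n",
    "\n",
    "def flagged_rules(sql):\n",
    "    '''Return the names of all rules that the query trips.'''\n",
    "    return [name for name, rule in RULES.items() if rule(sql)]\n",
    "\n",
    "print(flagged_rules('SELECT name AS \"Full Name\" FROM staff;'))  # ['double_quotes']\n",
    "print(flagged_rules('SELECT name FROM staff;'))                 # []\n",
    "```\n",
    "\n",
    "With `datasets`, something like `sql_dataset.map(lambda x: {'flags': flagged_rules(x['sql'])})` should store the rule names per row, which makes it possible to count hits per rule before deciding what to drop."
   ]
  },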
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100000, 19238)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(sql_dataset), len(sql_dataset.filter(lambda x: has_double_quoted_strings(x[\"sql\"]) or has_having_without_group_by(x[\"sql\"]) or has_sum_without_null_handling(x[\"sql\"])))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
